#include "cell.h"
#include "comm.h"
#include "force_field.h"
// #include "verlet.h"
#include "shake.h"
#include "minimize.h"
#include "memory.h"
#include "memory_cpp.hpp"
#include "bonded.h"
// Housekeeping after atom coordinates have been shifted.
//
// The local result of cell_check() is combined across processes with MPI_LOR
// via comm_allreduce(). When the combined flag is set, atoms are exported and
// imported between cells, the remaining per-atom data is forwarded, and, if
// the force field has M_BOND set in ff->present, bonds are searched again.
//
// Returns the combined flag: non-zero when the rebuild path was taken.
int post_alpha_shift(ff_t *ff, mpp_t *mpp) {
  int rebuild_flag = cell_check(ff->grid);
  comm_allreduce(&rebuild_flag, 1, MPI_INT, MPI_LOR, mpp);

  // Nothing to do when no process asked for a rebuild.
  if (!rebuild_flag) {
    return rebuild_flag;
  }

  // Export / import sequence; the order of these four calls is significant.
  cell_export(ff->grid);
  forward_comm_export_list_cg(ff->grid, mpp);
  cell_import(ff->grid);
  forward_comm_most(ff->grid, mpp);

  const bool has_bonds = (ff->present & M_BOND) != 0;
  if (has_bonds) {
    find_bonds(ff->grid);
    //forward_comm_bond(ff->grid, mpp);
  }
  return rebuild_flag;
}
// Evaluate forces and the total energy for the minimizer's current state.
//
// Calls force() on the force field, then reverse_comm_f() on the grid and
// comm_allreduce_stat() on the force field's statistics (presumably folding
// ghost-atom forces back to their owners and summing per-process energies;
// confirm against comm.h). The sum of the six energy terms held in mdstat_t
// (ecoul, evdwl, ebond, eangle, etori, eimpr) is stored in mz->e.
//
// No per-atom scratch copy of the forces is made here: forces stay in the
// cells' own f arrays, so the routine works for any number of atoms.
void energy_force(minimizer_t *mz) {
  force(mz->ff);
  reverse_comm_f(mz->ff->grid, mz->mpp);
  mdstat_t *stat = &mz->ff->stat;
  comm_allreduce_stat(stat, mz->mpp);

  real etot = stat->ecoul + stat->evdwl + stat->ebond + stat->eangle + stat->etori + stat->eimpr;
  mz->e = etot;
}
// Move every local atom to position "alpha_new" along the search direction
// cg_h and re-evaluate energy and forces.
//
// Coordinates are advanced incrementally: only the difference between
// alpha_new and the previously applied step (mz->alpha_last) is added, so
// calling with alpha_new == 0 shifts the atoms back by the last step.
// Each call counts as one evaluation (mz->neval is incremented).
void alpha_step(minimizer_t *mz, real alpha_new) {
  // printf("Trying alpha=%f\n", alpha_new);
  cellgrid_t *cells = mz->ff->grid;

  // Increment relative to the step that is currently applied.
  real shift = alpha_new - mz->alpha_last;
  mz->alpha_last = alpha_new;

  FOREACH_LOCAL_CELL(cells, ix, iy, iz, cell) {
    for (int a = 0; a < cell->natom; ++a) {
      vecscaleaddv(cell->x[a], cell->x[a], cell->cg_h[a], 1, shift);
    }
  }

  // if (mz->rigid) {
  //   correct_coordinates(cells, mz->mpp);
  // }

  // When no rebuild took place, forward the coordinates explicitly.
  const int rebuilt = post_alpha_shift(mz->ff, mz->mpp);
  if (rebuilt == 0) {
    forward_comm_x(cells, mz->mpp);
  }

  mz->neval += 1;
  energy_force(mz);
}
// #define max(a, b) ((a) > (b) ? (a):(b))
// #define min(a, b) ((a) < (b) ? (a):(b))
// Line search along the current search direction cg_h, combining a quadratic
// (secant) projection with backtracking.
//
// mz     minimizer state; coordinates, forces and mz->e are updated through
//        alpha_step().
// eorig  energy of the configuration at alpha = 0.
//
// Returns 0 when a step was accepted, otherwise:
//   CG_FORCEZERO  every component of the search direction is zero;
//   CG_ZEROQUAD   f.h or its change between trials fell below EPS_QUAD;
//   CG_ZEROALPHA  the trial step was reduced until no usable step remained.
// For CG_ZEROQUAD and CG_ZEROALPHA the atoms are moved back to alpha = 0
// before returning.
int linemin_quadratic(minimizer_t *mz, real eorig) {
  mz->alpha_last = 0;
  cellgrid_t *grid = mz->ff->grid;
  mpp_t *mpp = mz->mpp;

  // fhorig: projection of the force onto the search direction at alpha = 0.
  // hmax:   largest magnitude of any search-direction component. It must be
  //         a magnitude (not a signed maximum) because it limits the largest
  //         displacement of a single coordinate to DMAX.
  real fhorig = 0, hmax = 0;
  FOREACH_LOCAL_CELL(grid, cx, cy, cz, cell) {
    for (int i = 0; i < cell->natom; i++) {
      fhorig += cell->f[i].x * cell->cg_h[i].x;
      fhorig += cell->f[i].y * cell->cg_h[i].y;
      fhorig += cell->f[i].z * cell->cg_h[i].z;
      real hx = fabs(cell->cg_h[i].x);
      real hy = fabs(cell->cg_h[i].y);
      real hz = fabs(cell->cg_h[i].z);
      if (hx > hmax) hmax = hx;
      if (hy > hmax) hmax = hy;
      if (hz > hmax) hmax = hz;
    }
  }
  comm_allreduce(&fhorig, 1, mpi_real, MPI_SUM, mpp);
  comm_allreduce(&hmax, 1, mpi_real, MPI_MAX, mpp);

  if (hmax == 0) {
    return CG_FORCEZERO;
  }

  real alphamax = min(ALPHAMAX, DMAX / hmax);
  // printf("hmax=%f, alphamax=%f\n", hmax, alphamax);
  real alpha = alphamax;

  // State of the previous trial point (initially the starting point).
  real fh_prev = fhorig;
  real e_prev = eorig;
  real alpha_prev = 0;
  while (1) {
    alpha_step(mz, alpha);

    // Force projected on the search direction at the trial point.
    real fh = 0;
    FOREACH_LOCAL_CELL(grid, cx, cy, cz, cell) {
      for (int i = 0; i < cell->natom; i++) {
        fh += vecdot(cell->f[i], cell->cg_h[i]);
      }
    }
    comm_allreduce(&fh, 1, mpi_real, MPI_SUM, mpp);

    real delfh = fh - fh_prev;
    if (fabs(fh) < EPS_QUAD || fabs(delfh) < EPS_QUAD) {
      alpha_step(mz, 0);
      return CG_ZEROQUAD;
    }

    // Trapezoid estimate of the previous energy from the current one; the
    // relative error against the actual previous energy tells whether the
    // energy is close enough to quadratic between the two trial points.
    real relerr = fabs(1.0 - (0.5 * (alpha - alpha_prev) * (fh + fh_prev) + mz->e) / e_prev);
    // Secant estimate of the alpha at which f.h vanishes.
    real alpha_proj = alpha - (alpha - alpha_prev) * fh / delfh;
    if (relerr < QTOL && alpha_proj > 0 && alpha_proj < alphamax) {
      alpha_step(mz, alpha_proj);
      if (mz->e - eorig < EMACH) {
        return 0;
      }
    }

    // Sufficient-decrease test: compare the achieved energy change with the
    // change predicted from the slope at the starting point.
    real de_ideal = -BACKSLOPE * alpha * fhorig;
    real de = mz->e - eorig;
    if (de <= de_ideal) {
      return 0;
    }

    // Remember this trial point and shrink the step.
    fh_prev = fh;
    e_prev = mz->e;
    alpha_prev = alpha;
    alpha *= ALPHA_REDUCE;
    if (alpha <= 0 || de_ideal >= -EMACH) {
      alpha_step(mz, 0);
      return CG_ZEROALPHA;
    }
  }
}

// Conjugate-gradient minimization using the Polak-Ribiere update with
// beta clamped at zero.
//
// mz     minimizer state (tolerances, evaluation budget, force field, comm).
// niter  maximum number of CG iterations.
//
// Returns:
//   a non-zero code from linemin_quadratic() when the line search fails;
//   CG_MAXEVAL when mz->neval exceeds mz->maxeval;
//   CG_ETOL    when the relative energy change of an iteration is below
//              mz->etol;
//   CG_FTOL    when the squared force norm is below mz->ftol squared;
//   0          when all niter iterations ran without meeting any of the
//              criteria above.
int min_cg(minimizer_t *mz, int niter) {
  // printf("rigid=%d\n", mz->rigid);
  cellgrid_t *grid = mz->ff->grid;
  if (mz->rigid) {
    // correct_coordinates(grid, mz->mpp);
  }
  energy_force(mz);

  // Start with steepest descent: g = h = f, and gg = |g|^2.
  real gg = 0;
  FOREACH_LOCAL_CELL(grid, cx, cy, cz, cell) {
    for (int i = 0; i < cell->natom; i++) {
      veccpy(cell->cg_g[i], cell->f[i]);
      veccpy(cell->cg_h[i], cell->f[i]);
      gg += vecdot(cell->f[i], cell->f[i]);
    }
  }
  comm_allreduce(&gg, 1, mpi_real, MPI_SUM, mz->mpp);
  for (int t = 0; t < niter; t++) {
    mdstat_t *stat = &mz->ff->stat;
    
    if (mz->rigid) {
      // correct_coordinates(grid, mz->mpp);
      energy_force(mz);
    }

    real e0 = mz->e;
    if (mz->mpp->pid == 0)
      printf("niter=%d, etot=%f, ecoul=%f, evdw=%f, ebond=%f, eangle=%f, etori=%f, eimpr=%f\n", t, mz->e, stat->ecoul, stat->evdwl, stat->ebond, stat->eangle, stat->etori, stat->eimpr);
    int fail = linemin_quadratic(mz, mz->e);
    if (fail)
      return fail;
    if (mz->neval > mz->maxeval) {
      return CG_MAXEVAL;
    }
    if (fabs(mz->e - e0) < mz->etol * 0.5 * (fabs(mz->e) + fabs(e0) + 1e-8))
      return CG_ETOL;

    // ff = |f_new|^2, fg = f_new . g_old
    real ff = 0, fg = 0;
    FOREACH_LOCAL_CELL(grid, cx, cy, cz, cell) {
      for (int i = 0; i < cell->natom; i++) {
        ff += vecdot(cell->f[i], cell->f[i]);
        fg += vecdot(cell->f[i], cell->cg_g[i]);
      }
    }
    comm_vallreduce(2, mpi_real, MPI_SUM, mz->mpp, &ff, &fg);
    if (ff < mz->ftol * mz->ftol)
      return CG_FTOL;
    real beta = max(0, (ff - fg) / gg);
    // The denominator of the next beta must be the squared norm of the
    // gradient that was just used, so carry it forward.
    gg = ff;

    // g = f_new, h = g + beta * h_old, gh = g . h
    real gh = 0;
    FOREACH_LOCAL_CELL(grid, cx, cy, cz, cell) {
      for (int i = 0; i < cell->natom; i++) {
        veccpy(cell->cg_g[i], cell->f[i]);
        vecscaleaddv(cell->cg_h[i], cell->cg_g[i], cell->cg_h[i], 1, beta);
        gh += vecdot(cell->cg_g[i], cell->cg_h[i]);
      }
    }
    comm_allreduce(&gh, 1, mpi_real, MPI_SUM, mz->mpp);
    // If h is not a descent direction, restart from steepest descent.
    if (gh <= 0) {
      FOREACH_LOCAL_CELL(grid, cx, cy, cz, cell) {
        for (int i = 0; i < cell->natom; i++) {
          veccpy(cell->cg_h[i], cell->cg_g[i]);
        }
      }
    }
  }
  // Iteration budget used up without meeting a stop criterion.
  // NOTE(review): if minimize.h defines a dedicated max-iteration code,
  // return that here instead of 0.
  return 0;
}
// Fill in a minimizer_t before use.
//
//   maxeval  stored in mz->maxeval (compared against mz->neval by min_cg)
//   etol     stored in mz->etol
//   ftol     stored in mz->ftol
//   ff       force field the minimizer operates on
//   mpp      communication handle passed to the comm_* routines
//
// The evaluation counter mz->neval is reset to zero. No other field of the
// structure (e.g. alpha_last, e, rigid) is written here.
void init_minimizer(minimizer_t *mz, int maxeval, real etol, real ftol, ff_t *ff, mpp_t *mpp) {
  // Handles the minimizer works through.
  mz->ff = ff;
  mz->mpp = mpp;

  // Convergence tolerances and evaluation budget.
  mz->etol = etol;
  mz->ftol = ftol;
  mz->maxeval = maxeval;

  // Fresh evaluation counter.
  mz->neval = 0;
}
